This is an R Markdown Notebook. When you execute code within the notebook, the results appear beneath the code.
Try executing this chunk by clicking the Run button within the chunk, or by placing your cursor inside it and pressing Ctrl+Shift+Enter (Cmd+Shift+Enter on macOS).
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(data.table)
library(plotly)
##
## Attaching package: 'plotly'
## The following object is masked from 'package:ggplot2':
##
## last_plot
## The following object is masked from 'package:stats':
##
## filter
## The following object is masked from 'package:graphics':
##
## layout
# Load the built-in iris data set into the workspace (data() loads, it does
# not attach)
data(iris)
dt <- iris
# Fix the RNG so the stratified split below is identical on every run
set.seed(7)
# Row indices of 80% of the data; createDataPartition samples within each
# Species level so the class balance is kept in both partitions
validation_index <- createDataPartition(dt$Species, p = 0.8, list = FALSE)
# The remaining 20% is held out for final validation; the selected 80% is used
# for training and cross-validation.
# NOTE: despite its name, dtTest is the 80% *training* partition (the name is
# kept because later cells refer to it).
# as.data.table() is used instead of `%>% data.table()` so this cell does not
# depend on plotly re-exporting the magrittr pipe.
dtValidation <- as.data.table(dt[-validation_index, ])
dtTest <- as.data.table(dt[validation_index, ])
# Rows and columns of the 80% partition
dim(dtTest)
## [1] 120 5
# Storage class of every column in the full data set
vapply(dt, class, FUN.VALUE = character(1))
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## "numeric" "numeric" "numeric" "numeric" "factor"
# First few rows of the full data set
head(dt)
# Target labels (factor levels of Species)
levels(dt$Species)
## [1] "setosa" "versicolor" "virginica"
# Class balance of the 80% partition: raw counts next to percentages
species_counts <- table(dtTest$Species)
percentage <- 100 * prop.table(species_counts)
cbind(freq = species_counts, percentage = percentage)
## freq percentage
## setosa 40 33.33333
## versicolor 40 33.33333
## virginica 40 33.33333
# Per-column descriptive statistics of the 80% partition
summary(dtTest)
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## Min. :4.300 Min. :2.000 Min. :1.000 Min. :0.100
## 1st Qu.:5.100 1st Qu.:2.800 1st Qu.:1.600 1st Qu.:0.300
## Median :5.700 Median :3.000 Median :4.250 Median :1.300
## Mean :5.829 Mean :3.026 Mean :3.758 Mean :1.192
## 3rd Qu.:6.400 3rd Qu.:3.300 3rd Qu.:5.100 3rd Qu.:1.800
## Max. :7.900 Max. :4.100 Max. :6.900 Max. :2.500
## Species
## setosa :40
## versicolor:40
## virginica :40
##
##
##
# Univariate plots: one box plot per measurement, with all points overlaid.
# Reshape the four numeric columns to long format: one row per
# (Species, Type, Value).
dtPlot <- melt(
  dtTest,
  id.vars = "Species",
  measure.vars = c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"),
  variable.name = "Type",
  value.name = "Value"
)
# Min-max scale each measurement separately so every variable spans [0, 1] and
# the distribution shapes can be compared side by side. A single global
# min/max would only rescale all values uniformly and leave the plot unchanged.
# A constant measurement (zero range) maps to 0 rather than NaN.
dtPlot[, Normalized := {
  value_range <- max(Value) - min(Value)
  if (value_range > 0) (Value - min(Value)) / value_range else rep(0, .N)
}, by = Type]
plot_ly(dtPlot, y = ~Normalized, color = ~Type, type = "box",
        boxpoints = "all", jitter = 0.3)
# Bar chart of the number of rows per species in the 80% partition
plot_ly(
  data = dtTest[, .N, keyby = Species],
  x = ~Species,
  y = ~N,
  type = "bar"
)
# TODO(review): multivariate plots (feature plot of attributes coloured by
# class, box plot for each attribute by species) were announced at this point
# but are not implemented.

# Evaluation harness: 10-fold cross-validation on the 80% partition, scored
# by classification accuracy.
control <- trainControl(method = "cv", number = 10)
metric <- "Accuracy"

# Fit one caret model that predicts Species from all other columns of dtTest.
# Re-seeding immediately before each train() call is the usual caret practice
# for giving every method the same CV folds, so resamples() can compare them
# on identical splits.
#
# method: caret model name, e.g. "lda" or "rpart".
# seed:   RNG seed set right before training.
# Returns the fitted caret `train` object.
fit_species_model <- function(method, seed = 7) {
  set.seed(seed)
  train(Species ~ ., data = dtTest, method = method,
        metric = metric, trControl = control)
}

# Linear Discriminant Analysis (LDA)
fit.lda <- fit_species_model("lda")
# Classification and Regression Trees (CART)
fit.cart <- fit_species_model("rpart")
# k-Nearest Neighbors (kNN)
fit.knn <- fit_species_model("knn")
# Support Vector Machines (SVM) with a radial basis function (RBF) kernel
fit.svm <- fit_species_model("svmRadial")
# Random Forest (RF)
fit.rf <- fit_species_model("rf")
# Gather the cross-validation results of all five models and summarize their
# Accuracy and Kappa distributions side by side
candidate_models <- list(
  lda = fit.lda,
  cart = fit.cart,
  knn = fit.knn,
  svm = fit.svm,
  rf = fit.rf
)
results <- resamples(candidate_models)
summary(results)
##
## Call:
## summary.resamples(object = results)
##
## Models: lda, cart, knn, svm, rf
## Number of resamples: 10
##
## Accuracy
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## lda 0.9166667 0.9375000 1.0000000 0.9750000 1.0000000 1 0
## cart 0.8333333 0.9166667 0.9166667 0.9250000 0.9791667 1 0
## knn 0.9166667 0.9166667 1.0000000 0.9666667 1.0000000 1 0
## svm 0.9166667 0.9166667 0.9166667 0.9500000 1.0000000 1 0
## rf 0.9166667 0.9166667 0.9166667 0.9500000 1.0000000 1 0
##
## Kappa
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## lda 0.875 0.90625 1.000 0.9625 1.00000 1 0
## cart 0.750 0.87500 0.875 0.8875 0.96875 1 0
## knn 0.875 0.87500 1.000 0.9500 1.00000 1 0
## svm 0.875 0.87500 0.875 0.9250 1.00000 1 0
## rf 0.875 0.87500 0.875 0.9250 1.00000 1 0
# Dot plot of each model's resampled Accuracy and Kappa for visual comparison
dotplot(results)
# LDA has the highest mean CV accuracy (0.975) in the summary above, so show
# its details (top-level autoprint calls print() on the train object)
fit.lda
## Linear Discriminant Analysis
##
## 120 samples
## 4 predictor
## 3 classes: 'setosa', 'versicolor', 'virginica'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 108, 108, 108, 108, 108, 108, ...
## Resampling results:
##
## Accuracy Kappa
## 0.975 0.9625
# Score the LDA model on the 20% hold-out partition that was never used in
# training or cross-validation
predictions <- predict(fit.lda, newdata = dtValidation)
confusionMatrix(data = predictions, reference = dtValidation$Species)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 10 0
## virginica 0 0 10
##
## Overall Statistics
##
## Accuracy : 1
## 95% CI : (0.8843, 1)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 4.857e-15
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 1.0000 1.0000
## Specificity 1.0000 1.0000 1.0000
## Pos Pred Value 1.0000 1.0000 1.0000
## Neg Pred Value 1.0000 1.0000 1.0000
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.3333 0.3333
## Detection Prevalence 0.3333 0.3333 0.3333
## Balanced Accuracy 1.0000 1.0000 1.0000